Skip to content

Support dspy.Tool as input field type and dspy.ToolCall as output field type #8242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 2, 2025

Conversation

chenmoneygithub
Copy link
Collaborator

@chenmoneygithub chenmoneygithub commented May 20, 2025

The main goal is making dspy.Tool a valid input type with special format so that users can easily use dspy.Tool without using dspy.ReAct. dspy.ToolCall is just a thing wrapper over tool name and tool args to simplify the output field definition involving tool calls.

Sample usage:

import dspy

dspy.configure(
    lm=dspy.LM("openai/gpt-4o-mini", cache=False),
    # lm=dspy.LM("anthropic/claude-3-5-sonnet-20240620", cache=False),
)


class MySignature(dspy.Signature):
    question: str = dspy.InputField(description="The question to answer")
    tools: list[dspy.Tool] = dspy.InputField(description="The tools to use")
    answer: str = dspy.OutputField(description="The answer to the question, if no more tool calls are needed")
    tool_call: dspy.ToolCalls = dspy.OutputField(description="The tool call information, including name and arguments")


def get_weather(city: str) -> str:
    """Get the weather for a city"""
    return f"The weather in {city} is sunny"


tools = [dspy.Tool(get_weather)]
tools_dict = {t.name: t for t in tools}

predict = dspy.Predict(MySignature)

result = predict(question="What is the weather in Paris?", tools=tools)
if result.tool_call:
    print(f"Executing tool: {result.tool_call.name}")
    print(f"Tool result: {tools_dict[result.tool_call.name](**result.tool_call.args)}")

dspy.inspect_history()

With sample output (not including history):

Executing tool: get_weather
Tool result: The weather in Paris is sunny

@chenmoneygithub
Copy link
Collaborator Author

@okhat @TomeHirata I am going to add some adapter change to this PR, will ping you for review after that is done.

@@ -181,6 +200,44 @@ def __str__(self):
return f"{self.name}{desc} {arg_desc}"


class ToolCalls(BaseType):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

q: Why do we make ToolCalls a first citizen rather than ToolCall? Other frameworks basically define ToolCall first (langchain) and then treat tool call response as list[ToolCall].

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you spotted the weirdo, and this is a design after multiple edits. I started with dspy.Tool, but then realized we are leaving the output field in free forms. For example, users might do:

class QAWithMultipleToolCall(dspy.Signature):
    question: str = dspy.InputField()
    tools: list[dspy.Tool] = dspy.InputField()
    answer: str = dspy.OutputField()
    tool_call_1: dspy.ToolCall = dspy.OutputField()
    tool_call_2: dspy.ToolCall = dspy.OutputField()

while we expect them to do:

class QAWithToolCall(dspy.Signature):
    question: str = dspy.InputField()
    tools: list[dspy.Tool] = dspy.InputField()
    answer: str = dspy.OutputField()
    tool_calls: list[dspy.ToolCall] = dspy.OutputField()

The bad thing about QAWithMultipleToolCall is when we use the native function calling (see JSONAdapter in this PR), we have no clue how to write the tool calls back to the output field. With a regulation that the output field must be dspy.ToolCalls, in native function calling use case, we can locate the output field and write the tool call info there.

For some more context, there is a caveat that OAI models are bad at outputting dict, so the current JSONAdapter is broken with tool calling. For a quick test, you can try using JSONAdapter + ReAct, in my testing that doesn't really work. Native function calling resolves this issue.

Copy link
Collaborator

@TomeHirata TomeHirata Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed response. I agree this is tricky. We could have a validation logic in dspy.Signature and disallow the usage of a single dspy.ToolCall field. Then, when we use the native function calling, how do we deal with the multiple tool call case? To my understanding, the native function calling does not provide any semantic grouping.

class QAWithToolCall(dspy.Signature):
    question: str = dspy.InputField()
    tools: list[dspy.Tool] = dspy.InputField()
    answer: str = dspy.OutputField()
    tier1_tool_calls: list[dspy.ToolCall] = dspy.OutputField()
    tier2_tool_calls: list[dspy.ToolCall] = dspy.OutputField()

value = self.parse(signature, text)
else:
value = {}
for field_name in signature.output_fields.keys():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chenmoneygithub qq - do we need to handle this tool-call specific logic here? I'm wondering that since it inherits from BaseType if we can add a parse function to Tool (and the BaseType interface) and port over all the logic there (similar to the format we use for custom types)? This could keep post-process generalizable, and we'd just add a check for if the signature includes a BaseType output field and parse accordingly. curious on thoughts here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Output handling is a bit different from input handling, which we can possibly generalize.

However I would be cautious about doing it because we don't know yet if generalization makes sense here - for the ToolCalls case, we are reading from the LM response and write back to the output field that of type dspy.ToolCalls if native function calling is used, which may be completely different from the second output field we introduce.

@@ -20,18 +27,78 @@ def __init_subclass__(cls, **kwargs) -> None:
cls.format = with_callbacks(cls.format)
cls.parse = with_callbacks(cls.parse)

def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]:
def _call_preprocess(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar q for the preprocess too, could this instead be wrapped in Tool's format and make use of split_message_content_for_custom_types?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for the tool calling, here we need to handle the special case of native function calling: https://platform.openai.com/docs/guides/function-calling?api-mode=chat, so we need to modify the LM call args in addition to the messages.

return None

def _get_tool_call_output_field_name(self, signature: Type[Signature]) -> bool:
for name, field in signature.output_fields.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when there are multiple fields with type ToolCalls?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then only one field will be populated with value. We can raise a warning when there are multiple ToolCalls field, I kinda doubt if users will do that though. We made it ToolCalls as an indicator that it has multiple ToolCalls.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

had the same question for multimodal. in a single call (diff from n), i don't believe the models can produce multiple multimodal outputs with 2+ dspy.OutputFields (it'll raise the output exception we see sometimes)

could be different for ToolCalls tho, but might be safer to give that warning globally (maybe for the select ones in types?)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with Arnav. For cases where users define multiple output fields with a type that can only have a single field, such as tool calls or multi-modal, can we raise an exception when the invalid signature is created rather than warning since this is a wrong usage of signature?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chenmoneygithub Since this PR is merged, can you follow up with another PR on this?

from dspy.primitives.program import Module
from dspy.primitives.tool import Tool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

q: do we have isort running on your local? Shall we include in our CI?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea I have isort locally, and I remember there are a few files we don't want to sort, we can check after merging #7885

@okhat okhat merged commit 99a07e9 into stanfordnlp:main Jun 2, 2025
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants